import copy
import os.path

from .ufedbase import UnlearnBasicClient, UnlearnBasicServer
import numpy as np
from utils import fmodule
import torch
import pickle
from tqdm import tqdm
from collections import OrderedDict
import random
import torch.nn.functional as F
import torch.optim as optim

class Server(UnlearnBasicServer):
    """Unlearning server: subtract the forgotten clients' recorded updates, then
    post-train the result by distilling from a snapshot of the original model.

    NOTE(review): attributes such as ``save_folder``, ``bd``, ``unlearn_clients``,
    ``unlearn_clients_id``, ``clients`` and ``data_to_device`` are not defined in
    this class and are presumably provided by ``UnlearnBasicServer`` — confirm.
    """

    def __init__(self, option, model, clients, data_loader, device=None):
        super(Server, self).__init__(option, model, clients, data_loader, device)
        # Unlearning config.
        # Unlike FedEraser, update_his only contains the updates of the target client(s).
        # Teacher for post-training distillation: a snapshot of the model as it is
        # when this server is constructed.
        self.teacher_model = copy.deepcopy(self.model)
        # Softmax temperature of the distillation (KL) term.
        self.temperature = 2.0
        # Weight of the KL term; the cross-entropy term is weighted by (1 - alpha).
        self.alpha = 0.5
        his_path = os.path.join(
            os.path.dirname(self.save_folder),
            'fedavg',
            'pretrained_history_fedkdu',
            f"kdu_his_s{self.option['split_num']}_c{self.option['class_num']}_bd_{self.bd}.pkl",
        )
        # NOTE(review): pickle.load can execute arbitrary code; only load history
        # files produced by this project's own pre-training run.
        with open(his_path, 'rb') as f:
            # Indexed by client id in unlearn_iterate; each entry is iterated as a
            # sequence of state dicts (presumably one per recorded round — confirm).
            self.update_his = pickle.load(f)
        self.u_rounds = 1

    def unlearn_iterate(self):
        """Subtract the recorded updates of every unlearn client from ``self.model``.

        For each id in ``self.unlearn_clients_id``, the state dicts stored in
        ``self.update_his[cid]`` are loaded into copies of the model, summed with
        ``fmodule._model_sum``, divided by ``len(self.clients)`` and subtracted
        from the global model. Clients with no recorded updates are skipped.
        """
        for cid in self.unlearn_clients_id:
            # Load each stored state dict into a throwaway model copy so fmodule's
            # model arithmetic (+, -, /) can be applied to it.
            updates = [self.load_param(copy.deepcopy(self.model), up) for up in self.update_his[cid]]
            if not updates:
                # Nothing recorded for this client, so there is nothing to subtract.
                continue
            # NOTE(review): dividing by the total number of clients presumably
            # mirrors the uniform aggregation used when the updates were recorded.
            target_up = fmodule._model_sum(updates) / len(self.clients)
            self.model = self.model - target_up

    def pt_iterate(self, max_steps=5, lr=0.001):
        """Post-train ``self.model`` by distilling from ``self.teacher_model``.

        For every client in ``self.unlearn_clients``, run at most ``max_steps``
        mini-batches from ``c.train_data`` through Adam, minimizing
        ``alpha * T**2 * KL(teacher || student) + (1 - alpha) * CE(student, y)``
        with ``T = self.temperature``.

        Args:
            max_steps: maximum number of mini-batches per client (default 5,
                the previously hard-coded value).
            lr: Adam learning rate (default 0.001).
        """
        for param in self.teacher_model.parameters():
            param.requires_grad = False
        # Freezing gradients is not enough: eval() also disables dropout and stops
        # BatchNorm running statistics from being updated by this data.
        self.teacher_model.eval()
        self.model.train()
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        # Every unlearn client is processed (previously an early `return` inside
        # this loop stopped after the first one).
        for c in self.unlearn_clients:
            for step, (batch_x, batch_y) in enumerate(c.train_data):
                if step >= max_steps:
                    break
                optimizer.zero_grad()
                batch_x = self.data_to_device(batch_x, device=self.device)
                batch_y = self.data_to_device(batch_y, device=self.device)
                with torch.no_grad():
                    teacher_outputs = self.teacher_model(batch_x)
                outputs = self.model(batch_x)

                teacher_soft_targets = F.softmax(teacher_outputs / self.temperature, dim=1)
                soft_targets = F.log_softmax(outputs / self.temperature, dim=1)

                # T**2 keeps the soft-target gradient magnitude comparable to the
                # hard-label term as the temperature changes.
                loss_kl = F.kl_div(soft_targets, teacher_soft_targets, reduction='batchmean') * (self.temperature ** 2)
                loss_ce = F.cross_entropy(outputs, batch_y)
                loss = self.alpha * loss_kl + (1.0 - self.alpha) * loss_ce
                loss.backward()
                optimizer.step()

    def get_param(self, m):
        """Return a copy of ``m``'s state dict with every tensor moved to the CPU."""
        return {k: v.cpu() for k, v in m.state_dict().items()}

    def load_param(self, m, param_dict):
        """Load ``param_dict`` into model ``m`` and return ``m`` on ``self.device``.

        The tensors are moved to ``self.device`` before loading, and the model is
        moved there afterwards; ``m`` itself is modified.
        """
        state_dict_on_gpu = {k: v.to(self.device) for k, v in param_dict.items()}
        m.load_state_dict(state_dict_on_gpu)
        m = m.to(self.device)
        return m




class Client(UnlearnBasicClient):
    """Client for this unlearning method; it adds nothing to ``UnlearnBasicClient``."""

    def __init__(self, option, id, model=None):
        # Pure pass-through: all client behaviour lives in the base class.
        super().__init__(option, id, model)
